#!/usr/bin/env pmpython
"""Test pmdaopentelemetry via exposing fake endpoints -*- python -*- """
#
# Copyright (C) 2025 Red Hat.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# for more details.
#

import time
import re
from six.moves import queue
import threading
import socket
from six.moves import BaseHTTPServer
import argparse
import json
from copy import deepcopy

# Maps an active endpoint number to the count of responses it has served so
# far; entries are added when an endpoint is activated and removed once the
# endpoint reaches its response limit.
activeEndpoint = {}


def write_endpoint_metadata(args=None, endpoint=None):
    """Write the URL of a fake endpoint to a ``source<endpoint>.url`` file.

    The file is created in ``args.output`` and contains a single line of the
    form ``http://localhost:<port>/<args.url><endpoint>``, where the port is
    taken from ``args.addr[1]``.

    args     -- object providing ``addr``, ``url`` and ``output`` attributes
    endpoint -- endpoint number appended to the URL and used in the file name
    """
    endpointStr = "http://{}:{}/{}{}".format("localhost", args.addr[1], args.url, endpoint)
    path = '{}/{}'.format(str(args.output), "source" + str(endpoint) + ".url")
    # The context manager closes the file even when the write fails.
    with open(path, 'w') as f:
        f.write("{}\n".format(endpointStr))

class FakeEndpoint(BaseHTTPServer.BaseHTTPRequestHandler):
    """Request handler serving synthetic OpenTelemetry metrics as JSON.

    Each GET for an active endpoint returns a document holding generated
    gauge, counter, summary and histogram metrics whose values depend on how
    many responses that endpoint has already served.  Configuration is read
    from ``self.server.args``; ``self.server.lock`` and ``self.server.pqueue``
    are used to retire endpoints and to activate queued ones.
    """

    def sample_gauge_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return a gauge metric dict with ``instances`` data points.

        iteration   -- multiplier applied to every data point value
        instances   -- number of data points to generate
        metrics     -- unused
        endpointNum -- number embedded in the metric name and description
        """
        gauge_value = 12.1
        instance_scale = 1
        gauge_string = {
            'name': 'sample_gauge',
            'unit': '1',
            'description': 'description',
            'gauge': {
                'aggregationTemporality': '2',
                'isMonotonic': 'True',
                'dataPoints': [{
                    'asDouble': '5',
                    'timeUnixNano': '1544712660300000000',
                    'attributes': [{'key': 'labels', 'value': {'stringValue': 'some.interesting.labels'}}],
                }],
            },
        }
        # The single literal data point is the template copied per instance.
        points = gauge_string['gauge']['dataPoints'][0]
        gauge_string['name'] = 'sample_gauge{:04d}'.format(endpointNum)
        gauge_string['description'] = 'sample_gauge{:04d} instance scale {}, value scale {}'.format(endpointNum, instance_scale, gauge_value)
        data = []
        for i in range(instances):
            point = deepcopy(points)
            format_instname = "{{bar=\'{:.1f}\'}}".format(i * instance_scale)
            point['asDouble'] = "{:.1f}".format(int(iteration) * i * gauge_value)
            point['attributes'].append({'key': 'instance', 'value': {'stringValue': format_instname}})
            data.append(point)
        gauge_string['gauge']['dataPoints'] = data
        return gauge_string

    def sample_counter_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return a counter (``sum``) metric dict with ``instances`` data points.

        iteration   -- multiplier applied to every data point value
        instances   -- number of data points to generate
        metrics     -- unused
        endpointNum -- number embedded in the metric name and description
        """
        counter_value = 1.70205394e+08
        instance_scale = 0.7
        counter_string = {
            'name': 'sample_counter',
            'unit': '1',
            'description': 'description',
            'sum': {
                'aggregationTemporality': '2',
                'isMonotonic': 'True',
                'dataPoints': [{
                    'asDouble': 10,
                    'timeUnixNano': '1544712660300000000',
                    'attributes': [{'key': 'labels', 'value': {'stringValue': 'some.interesting.labels'}}],
                }],
            },
        }
        # The single literal data point is the template copied per instance.
        points = counter_string['sum']['dataPoints'][0]
        counter_string['name'] = 'sample_counter{:04d}'.format(endpointNum)
        counter_string['description'] = 'sample_counter{:04d} instance scale {} value scale {}'.format(endpointNum, instance_scale, counter_value)
        data = []
        for i in range(instances):
            point = deepcopy(points)
            format_instname = "{{baz=\'{:.1f}\'}}".format(i * instance_scale)
            point['asDouble'] = "{:.8e}".format(int(iteration) * i * counter_value)
            point['attributes'].append({'key': 'instance', 'value': {'stringValue': format_instname}})
            data.append(point)
        counter_string['sum']['dataPoints'] = data
        return counter_string

    def sample_summary_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return a summary metric dict with ``instances`` quantile values.

        iteration   -- multiplier applied to the count, sum and quantile values
        instances   -- number of quantile entries to generate
        metrics     -- unused
        endpointNum -- number embedded in the metric name and description
        """
        summary_value_q0 = 0.000159623
        summary_value_sum = 1.3912067700000001
        summary_value_count = 1818
        instance_scale = 0.25
        summary_string = {
            'name': 'my.summary',
            'description': '',
            'unit': '',
            'summary': {
                'dataPoints': [{
                    'attributes': [{'key': 'labels', 'value': {'stringValue': 'some.interesting.labels'}}],
                    'timeUnixNano': 1741190626268846852,
                    'count': 1,
                    'sum': 99.9,
                    'quantileValues': [],
                }],
            },
        }
        points = summary_string['summary']['dataPoints'][0]
        summary_string['name'] = 'sample_summary{:04d}'.format(endpointNum)
        summary_string['description'] = 'sample_summary{:04d} instance scale {} value scale {}'.format(endpointNum, instance_scale, summary_value_q0)
        points["count"] = '{:d}'.format(int(iteration) * int(summary_value_count))
        # NOTE(review): '{:16f}' pads to width 16 with six decimals; '{:.16f}'
        # may have been intended -- left unchanged to keep the output stable.
        points["sum"] = '{:16f}'.format(int(iteration) * float(summary_value_sum))
        for i in range(instances):
            points["quantileValues"].append({
                'quantile': (i * instance_scale),
                'value': (int(iteration) * i * float(summary_value_q0)),
            })
        return summary_string

    def sample_histogram_data(self, iteration=None, instances=None, metrics=None, endpointNum=None):
        """Return a histogram metric dict with ``instances`` explicit bounds.

        One bucket count is produced per bound, followed by a final bucket
        count of 0.

        iteration   -- multiplier applied to the bucket counts
        instances   -- number of bounds; also scales the count and sum fields
        metrics     -- unused
        endpointNum -- number embedded in the metric name and description
        """
        histogram_value_l1 = 5945
        histogram_value_sum = 45157
        histogram_value_count = 18135
        instance_scale = 2
        histogram_string = {
            'name': 'my.histogram',
            'description': '',
            'unit': '',
            'histogram': {
                'dataPoints': [{
                    'attributes': [{}],
                    'timeUnixNano': 1741190626268846852,
                    'count': 1,
                    'sum': 99.9,
                    'bucketCounts': [],
                    'explicitBounds': [],
                    'min': 99.9,
                    'max': 99.9,
                    'exemplars': [],
                }],
                'aggregationTemporality': 2,
            },
        }
        points = histogram_string['histogram']['dataPoints'][0]
        histogram_string['name'] = 'sample_histogram{:04d}'.format(endpointNum)
        histogram_string['description'] = 'sample_histogram{:04d} instance scale {} value scale {} (sample histogram has instances)'.format(endpointNum, instance_scale, histogram_value_l1)
        # NOTE(review): count and sum scale with ``instances`` here, whereas
        # the summary scales them with ``iteration`` -- confirm this is
        # intended; left unchanged to keep the output stable.
        points["count"] = "{:d}".format(int(instances) * int(histogram_value_count))
        points["sum"] = "{:d}".format(int(instances) * int(histogram_value_sum))
        for i in range(instances):
            points["bucketCounts"].append(int(iteration) * i * histogram_value_l1)
            points["explicitBounds"].append(i * instance_scale)
        points["bucketCounts"].append(0)
        return histogram_string

    def start_string(self):
        """Return the document skeleton with an empty ``metrics`` list."""
        start = {
            'resourceMetrics': [{
                'resource': {
                    'attributes': [{'key': 'service.name', 'value': {'stringValue': 'my.service'}}],
                },
                'scopeMetrics': [{
                    'scope': {
                        'name': 'my.library',
                        'version': '1.0.0',
                        'attributes': [{'key': 'my.scope.attribute', 'value': {'stringValue': 'some.scope.attribute'}}],
                    },
                    'metrics': [],
                }],
            }],
        }
        return start

    def format_opentelemetry_output(self, metrics=None, iteration=None, instances=None, error=None, endpointNum=None):
        """Return the JSON text of one endpoint response.

        Metrics are appended in the repeating order gauge, counter, summary,
        histogram.  A gauge or counter consumes one unit of the ``metrics``
        budget, a summary or histogram consumes ``instances`` units, and
        ``iteration`` is incremented after every metric.

        metrics     -- metric budget for the response
        iteration   -- starting value multiplier handed to the generators
        instances   -- number of instances per metric
        error       -- when not None, only the first ``1/error`` of the JSON
                       text is returned, yielding a truncated document
        endpointNum -- unused; generated names use the consumed budget instead
        """
        metrics_remaining = metrics
        endpoint_string = self.start_string()
        helper = endpoint_string['resourceMetrics'][0]['scopeMetrics'][0]['metrics']

        # A new cycle starts while the budget is >= 0, so a gauge is emitted
        # even when the budget is exactly exhausted; within a cycle the
        # remaining generators only run while the budget is still positive.
        while metrics_remaining >= 0:
            helper.append(self.sample_gauge_data(iteration, instances, metrics_remaining, (metrics - metrics_remaining)))
            metrics_remaining -= 1
            iteration += 1
            if metrics_remaining <= 0:
                break
            helper.append(self.sample_counter_data(iteration, instances, metrics_remaining, (metrics - metrics_remaining)))
            metrics_remaining -= 1
            iteration += 1
            if metrics_remaining <= 0:
                break
            helper.append(self.sample_summary_data(iteration, instances, metrics_remaining, (metrics - metrics_remaining)))
            metrics_remaining -= instances
            iteration += 1
            if metrics_remaining <= 0:
                break
            helper.append(self.sample_histogram_data(iteration, instances, metrics_remaining, (metrics - metrics_remaining)))
            metrics_remaining -= instances
            iteration += 1
        endpoint_string = json.dumps(endpoint_string, sort_keys=False, indent=4, separators=(',', ': '))
        if error is not None:
            return endpoint_string[:int(len(endpoint_string) / int(error))]
        return endpoint_string

    def do_GET(self):
        """Serve one request for a fake endpoint.

        The endpoint number is whatever follows ``args.url`` in the request
        path.  Requests whose path does not end in the number of an active
        endpoint get a 404.  When no endpoint is active at all, nothing is
        written in reply.  After a successful response the endpoint's response
        count is incremented; once it reaches ``args.limit`` the endpoint is
        retired and the next queued endpoint, if any, is activated.
        """
        args = self.server.args
        time.sleep(float(args.delay))
        if not activeEndpoint:
            return

        parts = re.split(re.compile(str(args.url)), self.path)
        try:
            endpointNum = int(parts[1])
        except (IndexError, ValueError):
            # No endpoint suffix, or a suffix that is not a number.
            self.send_error(404)
            return

        # Snapshot the response count once so the rest of the request works
        # from a single consistent value.
        served = activeEndpoint.get(endpointNum)
        if served is None:
            self.send_error(404)
            return

        # Every ``args.error``-th response (but never the first one) is
        # truncated to simulate erroneous data.
        error_every = int(args.error)
        if error_every and served and served % error_every == 0:
            error = error_every
        else:
            error = None

        self.send_response(200)
        self.end_headers()
        self.wfile.write(self.format_opentelemetry_output(int(args.metrics), endpointNum * served, int(args.instances), error, endpointNum).encode())

        with self.server.lock:
            activeEndpoint[endpointNum] += 1
            if activeEndpoint[endpointNum] >= int(args.limit):
                del activeEndpoint[endpointNum]
                self.server.pqueue.task_done()
                try:
                    _next = self.server.pqueue.get_nowait()
                except queue.Empty:
                    pass
                else:
                    activeEndpoint[int(_next)] = 0
                    write_endpoint_metadata(args, _next)

class PrometheusEndpoint(threading.Thread):
    """Daemon thread that activates one queued endpoint and serves requests.

    Every instance runs its own ``HTTPServer`` on the already-listening
    socket supplied in ``args.sock``.  The thread starts itself from
    ``__init__``.
    """

    # One lock shared by all instances: the request handlers of every server
    # mutate the same module-level endpoint table, so a per-thread lock would
    # not provide any mutual exclusion between servers.
    lock = threading.Lock()

    def __init__(self, pqueue=None, args=None):
        """Store the endpoint queue and parsed arguments, then start the thread.

        pqueue -- queue of endpoint numbers waiting to be activated
        args   -- parsed arguments, extended with ``addr`` and ``sock``
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.pqueue = pqueue
        self.args = args
        self.server = None
        self.start()

    def run(self):
        """Activate the next queued endpoint, if any, then serve forever."""
        try:
            endpoint = self.pqueue.get_nowait()
        except queue.Empty:
            pass
        else:
            with self.lock:
                activeEndpoint[endpoint] = 0
                write_endpoint_metadata(self.args, endpoint)
        finally:
            # Started from ``finally`` so the server still runs when the
            # activation above fails.  bind_and_activate is False because the
            # shared socket is already bound and listening.
            httpd = BaseHTTPServer.HTTPServer(self.args.addr, FakeEndpoint, False)
            httpd.socket = self.args.sock
            httpd.pqueue = self.pqueue
            httpd.delay = self.args.delay
            httpd.args = self.args
            # The socket is shared between servers, so binding and closing
            # are turned into no-ops on this server object.
            httpd.server_bind = httpd.server_close = lambda: None
            httpd.lock = self.lock
            self.server = httpd
            self.server.serve_forever()

def parsing():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Numeric options are converted by argparse, so a malformed value is
    rejected at start-up instead of failing later while a request is served.
    """
    parser = argparse.ArgumentParser(description='Setup a number of fake opentelemetry endpoints for collection.')
    parser.add_argument('--endpoints', type=int, default=2, help='number of opentelemetry endpoints to start at once')
    parser.add_argument('--metrics', type=int, default=5, help='number of metrics per opentelemetry endpoint')
    parser.add_argument('--instances', type=int, default=5, help='number of instances per metric')
    parser.add_argument('--delay', type=float, default=0, help='delay time (seconds) for each "slow" node')
    parser.add_argument('--limit', type=int, default=5, help='number of iterations/responses each endpoint limits itself to')
    parser.add_argument('--total', type=int, default=5, help='total number of endpoints to run')
    parser.add_argument('--port', type=int, default=10000, help='port to start fake endpoints on')
    parser.add_argument('--error', type=int, default=0, help='create erroneous data in opentelemetry endpoints')
    parser.add_argument('--url', default="foo&endpoint=", help='url for the endpoint, the endpoint number must be the last part')
    parser.add_argument('--output', default="/tmp", help='directory to create endpoint metadata in')
    return parser.parse_args()

if __name__ == '__main__':
    # Bind one listening socket, queue every endpoint number, start the
    # server threads on the shared socket and wait until each queued endpoint
    # has been marked done.

    args = parsing()
    pendpointQueue = queue.Queue()
    addr = ('', int(args.port))
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(addr)
        # The backlog is a connection queue length, unrelated to the port
        # number that was previously passed here.
        sock.listen(socket.SOMAXCONN)
        args.addr = addr
        args.sock = sock

        for endpoint in range(int(args.total)):
            pendpointQueue.put(endpoint)

        pendpoints = [PrometheusEndpoint(pendpointQueue, args)
                      for _ in range(int(args.endpoints))]

        pendpointQueue.join()
    finally:
        # Release the listening socket even when start-up fails part-way.
        sock.close()
